import numpy as np

class CANN_2D:
    """Rate-based neural field on an N x N grid with a slow adaptation variable.

    State arrays, all of shape ``(N, N)``:

    * ``U_bump`` -- input variable, rectified (clipped at zero) after each step.
    * ``V_bump`` -- adaptation variable, relaxing toward ``m * U_bump``.
    * ``r``      -- rates, ``U**2 / (1 + k * sum(U**2))`` (divisive normalization).

    Recurrent input is ``J`` applied to ``r``: a circular 2-D convolution
    computed with FFTs when ``trans`` is true, otherwise a dense matrix
    product with ``J`` flattened to ``(N*N, N*N)``.
    """

    def __init__(self, N, k, m, J, tau, tau_v, trans):
        """Create a network with all state variables at zero.

        Args:
            N: side length of the square grid.
            k: strength of the divisive normalization in the rate equation.
            m: gain of U in the adaptation (V) equation.
            J: connectivity. With ``trans`` true it is FFT-convolved with
                ``r`` (presumably an (N, N) kernel -- confirm with callers);
                otherwise it must have N**4 elements so it can be reshaped to
                (N*N, N*N).
            tau: time constant of U.
            tau_v: time constant of V.
            trans: True -> FFT convolution, False -> dense matrix product.
        """
        self.length = N
        # Grid is square; ``width`` is kept as a public attribute because the
        # dense branch of ``interact`` historically referred to it.
        self.width = N
        self.k = k
        self.m = m
        self.J = J
        self.tau = tau
        self.tau_v = tau_v
        # N**2 / (2*pi)**2: units per unit area if the grid spans a
        # (2*pi)^2 domain (assumption). Not used by the methods below.
        self.rho = N**2 / (2 * np.pi) ** 2
        self.U_bump = np.zeros([N, N])
        self.V_bump = np.zeros_like(self.U_bump)
        self.r = np.zeros_like(self.U_bump)
        self.trans = trans

    def interact(self):
        """Return the recurrent input produced by the current rates ``r``.

        Returns:
            Real-valued array with the same shape as ``r``.
        """
        if self.trans:
            # Convolution theorem: pointwise product in frequency space gives
            # a circular (wrap-around) convolution; any imaginary part is
            # round-off and is discarded by ``np.real``.
            conv = np.fft.ifft2(np.fft.fft2(self.r) * np.fft.fft2(self.J))
            return np.real(conv)
        # Flatten the grid so J acts as an (N*N, N*N) matrix on the rate
        # vector; sizes come from ``r`` so no separate width bookkeeping is
        # needed.
        n_units = self.r.size
        flat = np.matmul(self.J.reshape(n_units, n_units), self.r.reshape(n_units))
        return np.real(flat.reshape(self.r.shape))

    def update(self, I_ext, dt):
        """Advance the network by one forward-Euler step of size ``dt``.

        Args:
            I_ext: external input added to the U equation; a scalar or an
                array broadcastable to the grid shape.
            dt: integration time step.
        """
        # Rates are computed from U *before* this step's update.
        u_sq = self.U_bump**2
        self.r = u_sq / (1 + self.k * u_sq.sum())
        self.U_bump = self.U_bump + dt / self.tau * (
            self.interact() + I_ext - self.U_bump - self.V_bump
        )
        # V is driven by the freshly updated, not-yet-rectified U.
        self.V_bump = self.V_bump + dt / self.tau_v * (self.m * self.U_bump - self.V_bump)
        # Rectify U in place (U_bump is a fresh array here, so no caller-owned
        # array is mutated).
        self.U_bump[self.U_bump < 0] = 0

    def reset(self):
        """Zero U, V and r, restoring the (N, N) shape created by ``__init__``."""
        self.U_bump = np.zeros([self.length, self.length])
        self.V_bump = np.zeros_like(self.U_bump)
        self.r = np.zeros_like(self.U_bump)

    def set(self, x):
        """Replace U with ``x`` when it matches the current U's shape.

        A shape mismatch is silently ignored (no error is raised). ``x`` is
        stored by reference, not copied; V and r are left unchanged.
        """
        if x.shape == self.U_bump.shape:
            self.U_bump = x

